
supports collective training with programs #18392

Merged
gavin1332 merged 5 commits into PaddlePaddle:develop from gavin1332:develop
Jul 2, 2019

Conversation

Collaborator

@gavin1332 gavin1332 commented Jun 28, 2019

test=develop

  1. Since the allreduce op has 4 reduce types, we split these four reduce types into four separate ops;
  2. We also refined the collective op code, e.g. we separated each collective op kernel into a CPUKernel and a CUDAKernel, and removed the device-specific DeviceContext template parameter, since the target DeviceContext is already known;
  3. We removed the newly added Collective op role to reduce the complexity of program and graph analysis;
  4. We appended a new ParamAttr field 'distributed' for marking distributed parameters, which means no gradient allreduce is performed for these parameters.
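The four-way split described in item 1 can be sketched in plain Python. This is illustrative only; the op names below are assumptions, not necessarily the names introduced by this PR:

```python
from functools import reduce
import operator

# Hypothetical sketch: instead of a single allreduce op carrying a
# `reduce_type` attribute, each reduce type becomes its own op.
REDUCE_FNS = {
    "c_allreduce_sum": sum,
    "c_allreduce_max": max,
    "c_allreduce_min": min,
    "c_allreduce_prod": lambda xs: reduce(operator.mul, xs, 1),
}

def simulate_allreduce(op_type, per_rank_values):
    """Apply the reduction named by the op type across all ranks' values."""
    return REDUCE_FNS[op_type](per_rank_values)
```

Only the sum variant needs a gradient kernel (the gradient of a sum allreduce is itself an allreduce); max, min, and prod are forward-only, which is one motivation for splitting them out.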

@gavin1332 gavin1332 force-pushed the develop branch 3 times, most recently from a2a7b5d to dbe90e1 on June 28, 2019 03:00
LLMHao previously approved these changes Jun 28, 2019
Member

@guru4elephant guru4elephant left a comment


Please take care of the distributed argument in ParamAttr.

gradient_clip=None,
do_model_average=False):
do_model_average=False,
distributed=False):
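A minimal sketch of the proposed ParamAttr change (a hypothetical simplification, not Paddle's real implementation), showing how a distributed flag could exclude a parameter from gradient allreduce:

```python
# Hypothetical simplified ParamAttr: parameters marked distributed=True
# are skipped when collecting gradients for allreduce.
class ParamAttr:
    def __init__(self, name=None, gradient_clip=None,
                 do_model_average=False, distributed=False):
        self.name = name
        self.gradient_clip = gradient_clip
        self.do_model_average = do_model_average
        self.distributed = distributed

def params_needing_grad_allreduce(param_attrs):
    """Only non-distributed parameters get their gradients allreduced."""
    return [p for p in param_attrs if not p.distributed]
```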
Member

Do you have an example for distributed=True?
I wonder whether it is possible to infer the value of distributed from the op type.

Collaborator Author

A solution has been found, so we will remove this attribute from ParamAttr. And since it belongs to the Distributed FC domain, we will handle this parameter in the next PR.


def __init__(self):
Collective.__init__(self)
def __init__(self, nrings=2):
Member

why 2?

Collaborator Author

@gavin1332 gavin1332 Jun 28, 2019


It is the minimum number for parallel communicators/streams. Since parallelizing collective communication does no harm in GradAllReduce mode, we prefer this value over the default of 1, which means no parallelism at all.
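The benefit of nrings > 1 can be illustrated by a round-robin assignment of gradient tensors to communication rings, so allreduce calls on different rings can overlap. A plain-Python sketch with illustrative names, not Paddle's internals:

```python
# Assign gradient tensors to `nrings` communication rings round-robin.
# Each ring maps to its own communicator/stream, so the collective
# calls issued on different rings can run in parallel.
def assign_rings(grad_names, nrings=2):
    rings = [[] for _ in range(nrings)]
    for i, name in enumerate(grad_names):
        rings[i % nrings].append(name)
    return rings
```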

Member

@guru4elephant guru4elephant left a comment

and also please remove shard_index_op in this PR

test=develop
@gavin1332 gavin1332 requested a review from chengduoZH June 28, 2019 08:22
@gavin1332
Collaborator Author

and also please remove shard_index_op in this PR

As shard_index_op is also in the Distributed FC domain, we have removed it.

test=develop
static_cast<int>(OpRole::kCollective) |
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kCollective) |
static_cast<int>(OpRole::kOptimize),
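The role composition in the fragment above is a bitwise OR of flags; a minimal Python sketch of the same idea (the flag values below are illustrative, not Paddle's actual constants):

```python
from enum import IntFlag

# Illustrative op-role flags, mirroring the C++ static_cast expressions
# above; kCollective was the newly proposed role, later removed in this PR.
class OpRole(IntFlag):
    kForward = 0x0000
    kBackward = 0x0001
    kOptimize = 0x0002
    kCollective = 0x0010  # hypothetical value

def has_role(op_role, role):
    """Check whether a combined role value contains a given role bit."""
    return bool(op_role & role)
```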
Contributor

Op roles increase the complexity of Graph and Program analysis. I do not recommend adding a new op role.

Collaborator Author

I agree with you, and I will try to remove the newly added op roles. However, since collective ops exchange data among trainers, their behavior differs somewhat from backward and optimize ops; I will create a new PR to discuss this topic if necessary. Thanks.

Collaborator Author

done

@chengduoZH
Contributor

Since the allreduce op has 4 reduce types, and only 'sum' has a corresponding gradient calculation while the others, 'max', 'min', and 'prod', do not, we split these four reduce types into four ops;

If the operator in allreduce is sum, does that mean allreduce is used for gradient aggregation? What is the logic here?

As collective ops could be used in backward and optimization phase,

Why? This is an unreasonable assumption.

@gavin1332
Collaborator Author

If the operator in allreduce is sum, does that mean allreduce is used for gradient aggregation? What is the logic here?

We are trying to introduce a model-parallel strategy for training extremely large classification problems in face recognition, which involve up to 10 million classes; the parameter of the last FC layer exceeds GPU memory. So we have to split the parameter across multiple cards and call collective ops in the forward phase, in addition to gradient aggregation.
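The sharding idea can be sketched as follows (an illustrative helper, not part of this PR): each card owns a contiguous slice of the class dimension of the last FC layer, so the full weight never has to fit on a single GPU.

```python
# Split `num_classes` output classes across `num_cards` GPUs as evenly
# as possible; each card allocates only its slice of the FC weight.
def shard_classes(num_classes, num_cards):
    """Return the [start, end) class range owned by each card."""
    base, rem = divmod(num_classes, num_cards)
    ranges, start = [], 0
    for card in range(num_cards):
        size = base + (1 if card < rem else 0)
        ranges.append((start, start + size))
        start += size
    return ranges
```

With 10 million classes over 8 cards, each card holds only 1.25 million columns of the FC weight; the forward pass then uses collective ops to combine the per-card partial logits.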

@gavin1332
Collaborator Author

gavin1332 commented Jul 1, 2019

As collective ops could be used in backward and optimization phase,

Why? This is an unreasonable assumption.

Recent research has produced algorithms that accelerate deep-learning training, namely LocalSGD, which allreduces and averages the parameters in the optimization phase instead of allreducing the gradients in the backward phase; our assumption is based on this.
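The LocalSGD idea can be sketched in plain Python (a toy scalar model, simplified by assuming the gradient stays constant across the local steps):

```python
# Each worker runs a few local SGD steps independently; then the
# *parameters* (not gradients) are allreduced and averaged, which is
# why collective ops can appear in the optimization phase.
def local_sgd_round(params_per_worker, grads_per_worker, lr=0.1, local_steps=2):
    # Local updates: each worker steps on its own (gradient held constant
    # over the local steps for simplicity of this sketch).
    updated = [
        p - lr * local_steps * g
        for p, g in zip(params_per_worker, grads_per_worker)
    ]
    # Parameter allreduce + average replaces per-step gradient allreduce.
    avg = sum(updated) / len(updated)
    return [avg] * len(updated)
```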

guru4elephant previously approved these changes Jul 2, 2019
Member

@guru4elephant guru4elephant left a comment

LGTM. We should add a collective op backward unit test if possible.

Member

@guru4elephant guru4elephant left a comment

LGTM

@gavin1332 gavin1332 merged commit a873fa8 into PaddlePaddle:develop Jul 2, 2019